import os

import cv2
import numpy as np
import tensorflow as tf

from .utils.io import read_text_file

# Command-line configuration for the TFRecord conversion script.
flags = tf.app.flags
# Directory that main() looks in for the `<set>_<postfix>.txt` listing file.
flags.DEFINE_string('data_dir', '',
                    'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('postfix', '', 'postfix of dataset')
flags.DEFINE_string('set', 'train', 'Convert training set, validation set or '
                                    'merged set.')
# main() joins this with `<set>_<postfix>.record`, so it is used as a directory.
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
# Integer-valued flag: a string flag with an int default is rejected by
# absl-backed flag parsers, and a command-line value would arrive as a str.
# NOTE(review): not read by main() in this file.
flags.DEFINE_integer('image_size', 256, 'size of input image')
# NOTE(review): not read by main() in this file.
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
                                                          'difficult instances')
flags.DEFINE_boolean('channel_mean', False, 'Whether to compute channel mean value')
FLAGS = flags.FLAGS

# Accepted values for --set; main() raises ValueError for anything else.
SETS = ['train', 'val']


def _load_image(img_path):
    """Read an image with OpenCV, raising IOError instead of returning None.

    Args:
        img_path: Path of the image file to decode.

    Returns:
        The decoded image array as returned by ``cv2.imread`` (default flags,
        i.e. a 3-channel BGR image).

    Raises:
        IOError: if OpenCV cannot read or decode the file.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure (missing/unreadable file) by returning
        # None rather than raising; surface the offending path explicitly.
        raise IOError('Could not read image: {}'.format(img_path))
    return img


def _make_example(img, label):
    """Build a tf.train.Example holding raw pixel bytes, label and image shape.

    Args:
        img: Decoded image array with shape (height, width, channels).
        label: Integer class label.

    Returns:
        A ``tf.train.Example`` with ``label``, ``img_raw``, ``height``,
        ``width`` and ``channel`` features.
    """
    return tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
        'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[img.shape[0]])),
        'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[img.shape[1]])),
        'channel': tf.train.Feature(int64_list=tf.train.Int64List(value=[img.shape[2]]))
    }))


def main(_):
    """Convert the ``<set>_<postfix>.txt`` listing into a TFRecord file.

    Each listing line is ``<image path>&!&<integer label>``. Every image is
    decoded with OpenCV and written, together with its label and shape, to
    ``<output_path>/<set>_<postfix>.record``. When ``--channel_mean`` is set,
    the average of the per-image channel means (each image weighted equally,
    channels in OpenCV's BGR order) is printed at the end. Otherwise a zero
    vector is printed.

    Raises:
        ValueError: if ``--set`` is not one of ``SETS`` or a line is malformed.
        IOError: if an image listed in the file cannot be read.
    """
    if FLAGS.set not in SETS:
        raise ValueError('set must be in : {}'.format(SETS))

    data_dir = FLAGS.data_dir
    record_path = os.path.join(
        FLAGS.output_path, '{}_{}.record'.format(FLAGS.set, FLAGS.postfix))
    examples_path = os.path.join(data_dir, '{}_{}.txt'.format(FLAGS.set, FLAGS.postfix))
    examples_list = read_text_file(examples_path)
    total = len(examples_list)

    # Running sum of per-image channel means. It is divided by `total` once at
    # the end; float64 keeps the sum precise even for very large datasets.
    channel_sum = np.zeros(3, np.float64)
    with tf.python_io.TFRecordWriter(record_path) as writer:
        for idx, line in enumerate(examples_list):
            img_path, label = line.split('&!&')
            img = _load_image(img_path)
            if idx % 500 == 0:
                print('On image {} of {}'.format(idx, total))
            if FLAGS.channel_mean:
                channel_sum += np.mean(img, axis=(0, 1))
            writer.write(_make_example(img, int(label)).SerializeToString())

    # Guard against an empty listing to avoid dividing by zero.
    mean = channel_sum / total if total else channel_sum
    print(mean)


if __name__ == '__main__':
    # Parse the command-line flags, then invoke this module's main().
    tf.app.run(main=main)
